{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TensorFlow2教程-用keras构建自己的网络层"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1 构建一个简单的网络层\n",
    "我们可以通过继承tf.keras.layers.Layer,实现一个自定义的网络层。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/doit/anaconda3/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n"
     ]
    }
   ],
   "source": [
    "\n",
    "from __future__ import absolute_import, division, print_function\n",
    "import tensorflow as tf\n",
    "tf.keras.backend.clear_session()\n",
    "import tensorflow.keras as keras\n",
    "import tensorflow.keras.layers as layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(\n",
      "[[-0.07703826  0.02132557  0.06587847  0.11276232]\n",
      " [-0.07703826  0.02132557  0.06587847  0.11276232]\n",
      " [-0.07703826  0.02132557  0.06587847  0.11276232]], shape=(3, 4), dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "# Defining a layer means declaring its weights and the computation from inputs to outputs.\n",
    "class MyLayer(layers.Layer):\n",
    "    \"\"\"Fully connected layer whose variables are created directly in __init__.\n",
    "\n",
    "    Args:\n",
    "        input_dim: Size of the last axis of the inputs.\n",
    "        unit: Number of output features.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dim=32, unit=32):\n",
    "        super(MyLayer, self).__init__()\n",
    "\n",
    "        # Kernel values come from a random-normal initializer.\n",
    "        kernel_values = tf.random_normal_initializer()(\n",
    "            shape=(input_dim, unit), dtype=tf.float32)\n",
    "        self.weight = tf.Variable(initial_value=kernel_values, trainable=True)\n",
    "\n",
    "        # Bias values come from a zeros initializer.\n",
    "        bias_values = tf.zeros_initializer()(shape=(unit,), dtype=tf.float32)\n",
    "        self.bias = tf.Variable(initial_value=bias_values, trainable=True)\n",
    "\n",
    "    def call(self, inputs):\n",
    "        \"\"\"Return ``matmul(inputs, weight) + bias``.\"\"\"\n",
    "        projected = tf.matmul(inputs, self.weight)\n",
    "        return projected + self.bias\n",
    "        \n",
    "# Smoke test: a (3, 5) tensor of ones through a 5 -> 4 layer.\n",
    "x = tf.ones((3,5))\n",
    "my_layer = MyLayer(5, 4)\n",
    "out = my_layer(x)\n",
    "print(out)\n",
    "        \n",
    "        "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "按上面构建网络层，图层会自动跟踪权重w和b，当然我们也可以直接用add_weight的方法构建权重"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(\n",
      "[[ 0.05495948 -0.07037501  0.20973718 -0.08516831]\n",
      " [ 0.05495948 -0.07037501  0.20973718 -0.08516831]\n",
      " [ 0.05495948 -0.07037501  0.20973718 -0.08516831]], shape=(3, 4), dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "class MyLayer(layers.Layer):\n",
    "    \"\"\"Fully connected layer whose variables are registered through add_weight.\n",
    "\n",
    "    Args:\n",
    "        input_dim: Size of the last axis of the inputs.\n",
    "        unit: Number of output features.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dim=32, unit=32):\n",
    "        super(MyLayer, self).__init__()\n",
    "        # add_weight creates each variable and makes the layer track it.\n",
    "        kernel_init = keras.initializers.RandomNormal()\n",
    "        self.weight = self.add_weight(\n",
    "            shape=(input_dim, unit), initializer=kernel_init, trainable=True)\n",
    "        bias_init = keras.initializers.Zeros()\n",
    "        self.bias = self.add_weight(\n",
    "            shape=(unit,), initializer=bias_init, trainable=True)\n",
    "\n",
    "    def call(self, inputs):\n",
    "        \"\"\"Return ``matmul(inputs, weight) + bias``.\"\"\"\n",
    "        projected = tf.matmul(inputs, self.weight)\n",
    "        return projected + self.bias\n",
    "        \n",
    "# Same smoke test as before: a (3, 5) tensor of ones through a 5 -> 4 layer.\n",
    "x = tf.ones((3,5))\n",
    "my_layer = MyLayer(5, 4)\n",
    "out = my_layer(x)\n",
    "print(out)\n",
    "        "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "也可以设置不可训练的权重"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[3. 3. 3.]\n",
      "[6. 6. 6.]\n",
      "weight: [<tf.Variable 'Variable:0' shape=(3,) dtype=float32, numpy=array([6., 6., 6.], dtype=float32)>]\n",
      "non-trainable weight: [<tf.Variable 'Variable:0' shape=(3,) dtype=float32, numpy=array([6., 6., 6.], dtype=float32)>]\n",
      "trainable weight: []\n"
     ]
    }
   ],
   "source": [
    "class AddLayer(layers.Layer):\n",
    "    \"\"\"Layer that keeps a non-trainable running sum of its inputs over axis 0.\n",
    "\n",
    "    Args:\n",
    "        input_dim: Length of the accumulator vector.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dim=32):\n",
    "        super(AddLayer, self).__init__()\n",
    "        # State only: stored on the layer but not trained.\n",
    "        zeros_init = keras.initializers.Zeros()\n",
    "        self.sum = self.add_weight(\n",
    "            shape=(input_dim,), initializer=zeros_init, trainable=False)\n",
    "\n",
    "    def call(self, inputs):\n",
    "        \"\"\"Add the axis-0 sum of ``inputs`` to the accumulator and return it.\"\"\"\n",
    "        batch_total = tf.reduce_sum(inputs, axis=0)\n",
    "        self.sum.assign_add(batch_total)\n",
    "        return self.sum\n",
    "        \n",
    "# Call the layer twice on the same (3, 3) tensor of ones: the stored sum keeps accumulating.\n",
    "x = tf.ones((3,3))\n",
    "my_layer = AddLayer(3)\n",
    "out = my_layer(x)\n",
    "print(out.numpy())\n",
    "out = my_layer(x)\n",
    "print(out.numpy())\n",
    "# The accumulator was created with trainable=False, so it is expected under\n",
    "# weights and non_trainable_weights but not under trainable_weights.\n",
    "print('weight:', my_layer.weights)\n",
    "print('non-trainable weight:', my_layer.non_trainable_weights)\n",
    "print('trainable weight:', my_layer.trainable_weights)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "如果定义网络层时还不知道输入的维度，可以重写build()函数，用调用时获得的input_shape来构建权重"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(\n",
      "[[-0.25201735  0.09862914  0.06587204]\n",
      " [-0.25201735  0.09862914  0.06587204]\n",
      " [-0.25201735  0.09862914  0.06587204]], shape=(3, 3), dtype=float32)\n",
      "tf.Tensor(\n",
      "[[-0.0270178  -0.03847811 -0.09622537]\n",
      " [-0.0270178  -0.03847811 -0.09622537]], shape=(2, 3), dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "class MyLayer(layers.Layer):\n",
    "    \"\"\"Fully connected layer that defers weight creation to build().\n",
    "\n",
    "    Args:\n",
    "        unit: Number of output features.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, unit=32):\n",
    "        super(MyLayer, self).__init__()\n",
    "        self.unit = unit\n",
    "\n",
    "    def build(self, input_shape):\n",
    "        \"\"\"Create the weights; the input size is the last axis of ``input_shape``.\"\"\"\n",
    "        in_features = input_shape[-1]\n",
    "        self.weight = self.add_weight(\n",
    "            shape=(in_features, self.unit),\n",
    "            initializer=keras.initializers.RandomNormal(),\n",
    "            trainable=True)\n",
    "        self.bias = self.add_weight(\n",
    "            shape=(self.unit,),\n",
    "            initializer=keras.initializers.Zeros(),\n",
    "            trainable=True)\n",
    "\n",
    "    def call(self, inputs):\n",
    "        \"\"\"Return ``matmul(inputs, weight) + bias``.\"\"\"\n",
    "        projected = tf.matmul(inputs, self.weight)\n",
    "        return projected + self.bias\n",
    "        \n",
    "\n",
    "# Input with last axis 5: build() sizes the kernel from the input shape.\n",
    "my_layer = MyLayer(3)\n",
    "x = tf.ones((3,5))\n",
    "out = my_layer(x)\n",
    "print(out)\n",
    "# A new instance is used for the (2, 2) input, whose last axis differs.\n",
    "my_layer = MyLayer(3)\n",
    "\n",
    "x = tf.ones((2,2))\n",
    "out = my_layer(x)\n",
    "print(out)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2 使用子层递归构建网络层\n",
    "可以在自定义网络层中调用其他自定义网络层"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trainable weights: 0\n",
      "trainable weights: 6\n"
     ]
    }
   ],
   "source": [
    "class MyBlock(layers.Layer):\n",
    "    \"\"\"Three stacked MyLayer sub-layers (32, 16, 2 units) with ReLU in between.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super(MyBlock, self).__init__()\n",
    "        # The sub-layers are themselves custom layers.\n",
    "        self.layer1 = MyLayer(32)\n",
    "        self.layer2 = MyLayer(16)\n",
    "        self.layer3 = MyLayer(2)\n",
    "\n",
    "    def call(self, inputs):\n",
    "        \"\"\"Run the three sub-layers in order; the last one has no activation.\"\"\"\n",
    "        hidden = tf.nn.relu(self.layer1(inputs))\n",
    "        hidden = tf.nn.relu(self.layer2(hidden))\n",
    "        return self.layer3(hidden)\n",
    "    \n",
    "my_block = MyBlock()\n",
    "# Before the first call the sub-layers have not been built yet.\n",
    "print('trainable weights:', len(my_block.trainable_weights))\n",
    "y = my_block(tf.ones(shape=(3, 64)))\n",
    "# Weights are created in build(), so they only exist after the first call.\n",
    "print('trainable weights:', len(my_block.trainable_weights)) "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "可以通过构建网络层的方法来收集loss，并可以递归调用。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "1\n",
      "1\n"
     ]
    }
   ],
   "source": [
    "class LossLayer(layers.Layer):\n",
    "    \"\"\"Identity layer that records ``rate * sum(inputs)`` as a loss.\n",
    "\n",
    "    Args:\n",
    "        rate: Multiplier applied to the summed inputs.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, rate=1e-2):\n",
    "        super(LossLayer, self).__init__()\n",
    "        self.rate = rate\n",
    "\n",
    "    def call(self, inputs):\n",
    "        \"\"\"Register the loss and pass ``inputs`` through unchanged.\"\"\"\n",
    "        penalty = self.rate * tf.reduce_sum(inputs)\n",
    "        self.add_loss(penalty)\n",
    "        return inputs\n",
    "\n",
    "class OutLayer(layers.Layer):\n",
    "    \"\"\"Wrapper layer whose only sub-layer is a LossLayer.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super(OutLayer, self).__init__()\n",
    "        self.loss_fun = LossLayer(rate=1e-2)\n",
    "\n",
    "    def call(self, inputs):\n",
    "        \"\"\"Delegate to the nested LossLayer and return its output.\"\"\"\n",
    "        outputs = self.loss_fun(inputs)\n",
    "        return outputs\n",
    "    \n",
    "my_layer = OutLayer()\n",
    "print(len(my_layer.losses))  # 0: the layer has not been called yet\n",
    "# tf.zeros takes the shape as its first argument; a second positional\n",
    "# argument would be interpreted as the dtype, so the shape is passed as a tuple.\n",
    "y = my_layer(tf.zeros((1, 1)))\n",
    "print(len(my_layer.losses))  # 1: the call registered one loss\n",
    "y = my_layer(tf.zeros((1, 1)))\n",
    "print(len(my_layer.losses))  # still 1: losses are reset at the start of each call\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "如果中间调用了keras网络层，里面的正则化loss也会被加入进来"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[<tf.Tensor: id=266, shape=(), dtype=float32, numpy=0.0020732714>]\n",
      "[<tf.Variable 'outer_layer/dense/kernel:0' shape=(1, 32) dtype=float32, numpy=\n",
      "array([[ 0.18419057,  0.41972226,  0.06816366,  0.3822255 , -0.18456893,\n",
      "        -0.25967044, -0.3724193 , -0.10103354,  0.28944224, -0.38763577,\n",
      "         0.26489502,  0.40765905,  0.3051821 ,  0.3081022 , -0.02279559,\n",
      "         0.16751188,  0.23520029,  0.28544015,  0.22495598,  0.21490109,\n",
      "         0.28766167, -0.2586367 , -0.04058033, -0.22604927, -0.12887421,\n",
      "         0.10930204,  0.41669363,  0.1015256 ,  0.0400646 ,  0.12960178,\n",
      "        -0.04219085, -0.32611725]], dtype=float32)>, <tf.Variable 'outer_layer/dense/bias:0' shape=(32,) dtype=float32, numpy=\n",
      "array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "      dtype=float32)>]\n"
     ]
    }
   ],
   "source": [
    "class OuterLayer(layers.Layer):\n",
    "    \"\"\"Wraps a Dense(32) layer that carries an L2 kernel regularizer.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super(OuterLayer, self).__init__()\n",
    "        # Regularization losses of a sub-layer are added to the outer layer's losses.\n",
    "        l2_penalty = tf.keras.regularizers.l2(1e-3)\n",
    "        self.dense = layers.Dense(32, kernel_regularizer=l2_penalty)\n",
    "\n",
    "    def call(self, inputs):\n",
    "        \"\"\"Forward ``inputs`` through the wrapped Dense layer.\"\"\"\n",
    "        outputs = self.dense(inputs)\n",
    "        return outputs\n",
    "\n",
    "\n",
    "my_layer = OuterLayer()\n",
    "y = my_layer(tf.zeros((1,1)))\n",
    "# The L2 penalty of the inner Dense kernel shows up in the outer layer's losses.\n",
    "print(my_layer.losses) \n",
    "# The inner Dense layer's kernel and bias are reported through the outer layer.\n",
    "print(my_layer.weights) "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 3 其他网络层配置\n",
    "## 3.1 使自己的网络层可以序列化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'name': 'linear', 'trainable': True, 'dtype': 'float32', 'units': 125}\n"
     ]
    }
   ],
   "source": [
    "class Linear(layers.Layer):\n",
    "    \"\"\"Serializable dense layer computing ``matmul(inputs, w) + b``.\n",
    "\n",
    "    Args:\n",
    "        units: Number of output features.\n",
    "        **kwargs: Forwarded to the base Layer constructor.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, units=32, **kwargs):\n",
    "        super(Linear, self).__init__(**kwargs)\n",
    "        self.units = units\n",
    "\n",
    "    def build(self, input_shape):\n",
    "        \"\"\"Create ``w`` and ``b``; the input size is the last axis of ``input_shape``.\"\"\"\n",
    "        in_features = input_shape[-1]\n",
    "        self.w = self.add_weight(\n",
    "            shape=(in_features, self.units),\n",
    "            initializer='random_normal',\n",
    "            trainable=True)\n",
    "        self.b = self.add_weight(\n",
    "            shape=(self.units,),\n",
    "            initializer='random_normal',\n",
    "            trainable=True)\n",
    "\n",
    "    def call(self, inputs):\n",
    "        \"\"\"Return ``matmul(inputs, w) + b``.\"\"\"\n",
    "        projected = tf.matmul(inputs, self.w)\n",
    "        return projected + self.b\n",
    "\n",
    "    def get_config(self):\n",
    "        \"\"\"Return the base config extended with ``units``, used for serialization.\"\"\"\n",
    "        layer_config = super(Linear, self).get_config()\n",
    "        layer_config['units'] = self.units\n",
    "        return layer_config\n",
    "    \n",
    "    \n",
    "layer = Linear(125)\n",
    "config = layer.get_config()\n",
    "print(config)\n",
    "# Rebuild a layer from the config (the class is known; the hyperparameters come from the config).\n",
    "new_layer = Linear.from_config(config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "如果在反序列化中(从配置中构建网络)需要更大的灵活性，可以重写from_config方法。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch of the method to override. NOTE(review): it takes `cls`, so it is\n",
    "# presumably meant to live inside a layer class as a classmethod; as written\n",
    "# here it is only a plain module-level function.\n",
    "def from_config(cls, config):\n",
    "    \"\"\"Build an instance of ``cls`` by passing ``config`` as keyword arguments.\"\"\"\n",
    "    return cls(**config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3.2 配置训练时特有参数\n",
    "有一些网络层， 如BatchNormalization层和Dropout层，在训练和推理中具有不同的行为，对于此类层，需要在call方法中使用training参数进行控制。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyDropout(layers.Layer):\n",
    "    \"\"\"Dropout that is active only when ``training`` is true.\n",
    "\n",
    "    Args:\n",
    "        rate: Drop rate forwarded to tf.nn.dropout.\n",
    "        **kwargs: Forwarded to the base Layer constructor.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, rate, **kwargs):\n",
    "        super(MyDropout, self).__init__(**kwargs)\n",
    "        self.rate = rate\n",
    "\n",
    "    def call(self, inputs, training=None):\n",
    "        \"\"\"Apply dropout in training mode, otherwise return ``inputs`` unchanged.\n",
    "\n",
    "        Args:\n",
    "            inputs: Tensor to apply dropout to.\n",
    "            training: Python bool, boolean tensor, or None (treated as False).\n",
    "        \"\"\"\n",
    "        # None (the default) cannot be converted to a tensor predicate for\n",
    "        # tf.cond, so treat it as inference.\n",
    "        if training is None:\n",
    "            training = False\n",
    "        # Plain Python booleans are resolved without a graph conditional.\n",
    "        if isinstance(training, bool):\n",
    "            if training:\n",
    "                return tf.nn.dropout(inputs, rate=self.rate)\n",
    "            return inputs\n",
    "        # Tensor-valued flag: choose the branch at run time.\n",
    "        return tf.cond(training,\n",
    "                       lambda: tf.nn.dropout(inputs, rate=self.rate),\n",
    "                       lambda: inputs)\n",
    "\n",
    "    def get_config(self):\n",
    "        \"\"\"Return the base config extended with ``rate`` so the layer can be rebuilt.\"\"\"\n",
    "        config = super(MyDropout, self).get_config()\n",
    "        config.update({'rate': self.rate})\n",
    "        return config\n",
    "    \n",
    "        "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 4 构建自己的模型\n",
    "通常，我们使用Layer类来定义内部计算块，并使用Model类来定义外部模型 - 即要训练的对象。\n",
    "\n",
    "Model类与Layer的区别：\n",
    "- 它对外开放内置的训练，评估和预测函数（model.fit(),model.evaluate(),model.predict()）。 \n",
    "- 它通过model.layers属性开放其内部网络层列表。 \n",
    "- 它对外开放保存和序列化API。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4.1 自定义模型\n",
    "下面通过构建一个变分自编码器（VAE），来介绍如何构建自己的网络， 并使用内置的函数进行训练。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sampling layer\n",
    "class Sampling(layers.Layer):\n",
    "    \"\"\"Draw ``z = z_mean + exp(0.5 * z_log_var) * epsilon`` with normal noise.\"\"\"\n",
    "\n",
    "    def call(self, inputs):\n",
    "        \"\"\"Take a ``(z_mean, z_log_var)`` pair and return the sampled ``z``.\"\"\"\n",
    "        z_mean, z_log_var = inputs\n",
    "        # The noise uses the run-time sizes of the first two axes of z_mean.\n",
    "        mean_shape = tf.shape(z_mean)\n",
    "        noise_shape = (mean_shape[0], mean_shape[1])\n",
    "        epsilon = tf.keras.backend.random_normal(shape=noise_shape)\n",
    "        scale = tf.exp(0.5 * z_log_var)\n",
    "        return z_mean + scale * epsilon\n",
    "# Encoder\n",
    "class Encoder(layers.Layer):\n",
    "    \"\"\"Map inputs to the triple ``(z_mean, z_log_var, z)``.\n",
    "\n",
    "    Args:\n",
    "        latent_dim: Units of the mean and log-variance heads.\n",
    "        intermediate_dim: Units of the hidden ReLU projection.\n",
    "        name: Layer name.\n",
    "        **kwargs: Forwarded to the base Layer constructor.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,\n",
    "                 latent_dim=32,\n",
    "                 intermediate_dim=64,\n",
    "                 name='encoder',\n",
    "                 **kwargs):\n",
    "        super(Encoder, self).__init__(name=name, **kwargs)\n",
    "        self.dense_proj = layers.Dense(intermediate_dim, activation='relu')\n",
    "        self.dense_mean = layers.Dense(latent_dim)\n",
    "        self.dense_log_var = layers.Dense(latent_dim)\n",
    "        self.sampling = Sampling()\n",
    "\n",
    "    def call(self, inputs):\n",
    "        \"\"\"Return ``(z_mean, z_log_var, z)`` for ``inputs``.\"\"\"\n",
    "        hidden = self.dense_proj(inputs)\n",
    "        # Both heads read the same hidden projection.\n",
    "        z_mean = self.dense_mean(hidden)\n",
    "        z_log_var = self.dense_log_var(hidden)\n",
    "        # Sample z from the predicted mean and log-variance.\n",
    "        z = self.sampling((z_mean, z_log_var))\n",
    "        return z_mean, z_log_var, z\n",
    "        \n",
    "# Decoder\n",
    "class Decoder(layers.Layer):\n",
    "    \"\"\"Map a latent tensor to ``original_dim`` sigmoid outputs.\n",
    "\n",
    "    Args:\n",
    "        original_dim: Units of the sigmoid output layer.\n",
    "        intermediate_dim: Units of the hidden ReLU projection.\n",
    "        name: Layer name.\n",
    "        **kwargs: Forwarded to the base Layer constructor.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,\n",
    "                 original_dim,\n",
    "                 intermediate_dim=64,\n",
    "                 name='decoder',\n",
    "                 **kwargs):\n",
    "        super(Decoder, self).__init__(name=name, **kwargs)\n",
    "        self.dense_proj = layers.Dense(intermediate_dim, activation='relu')\n",
    "        self.dense_output = layers.Dense(original_dim, activation='sigmoid')\n",
    "\n",
    "    def call(self, inputs):\n",
    "        \"\"\"Run the two dense layers: ReLU projection, then sigmoid output.\"\"\"\n",
    "        hidden = self.dense_proj(inputs)\n",
    "        reconstruction = self.dense_output(hidden)\n",
    "        return reconstruction\n",
    "    \n",
    "# Variational autoencoder\n",
    "class VAE(tf.keras.Model):\n",
    "    \"\"\"Encoder + Decoder model that registers a KL term through add_loss.\n",
    "\n",
    "    Args:\n",
    "        original_dim: Size of the reconstructed output.\n",
    "        latent_dim: Latent size passed to the Encoder.\n",
    "        intermediate_dim: Hidden size passed to both Encoder and Decoder.\n",
    "        name: Model name.\n",
    "        **kwargs: Forwarded to the base Model constructor.\n",
    "    \"\"\"\n",
    "\n",
    "    # NOTE(review): the default name 'encoder' is the same as the Encoder\n",
    "    # layer's default name; looks like a copy-paste - confirm before changing.\n",
    "    def __init__(self,\n",
    "                 original_dim,\n",
    "                 latent_dim=32,\n",
    "                 intermediate_dim=64,\n",
    "                 name='encoder',\n",
    "                 **kwargs):\n",
    "        super(VAE, self).__init__(name=name, **kwargs)\n",
    "\n",
    "        self.original_dim = original_dim\n",
    "        self.encoder = Encoder(\n",
    "            latent_dim=latent_dim, intermediate_dim=intermediate_dim)\n",
    "        self.decoder = Decoder(\n",
    "            original_dim=original_dim, intermediate_dim=intermediate_dim)\n",
    "\n",
    "    def call(self, inputs):\n",
    "        \"\"\"Encode, decode, register the KL loss and return the reconstruction.\"\"\"\n",
    "        z_mean, z_log_var, z = self.encoder(inputs)\n",
    "        reconstructed = self.decoder(z)\n",
    "        # KL term, summed over every element of z_mean / z_log_var.\n",
    "        kl_terms = z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + 1\n",
    "        kl_loss = -0.5 * tf.reduce_sum(kl_terms)\n",
    "        self.add_loss(kl_loss)\n",
    "        return reconstructed"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "训练VAE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 60000 samples\n",
      "Epoch 1/3\n",
      "60000/60000 [==============================] - 4s 58us/sample - loss: 1.0678\n",
      "Epoch 2/3\n",
      "60000/60000 [==============================] - 2s 32us/sample - loss: 0.0695\n",
      "Epoch 3/3\n",
      "60000/60000 [==============================] - 2s 32us/sample - loss: 0.0680\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7f88d00ac2b0>"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\n",
    "(x_train, _), _ = tf.keras.datasets.mnist.load_data()\n",
    "# Flatten each image to a 784-vector and divide by 255; -1 lets the sample\n",
    "# count follow the data instead of being hard-coded.\n",
    "x_train = x_train.reshape(-1, 784).astype('float32') / 255\n",
    "# Sizes are passed by keyword so they cannot be swapped by position.\n",
    "vae = VAE(784, latent_dim=32, intermediate_dim=64)\n",
    "optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)\n",
    "\n",
    "# MSE is the reconstruction loss; the model also registers a KL term via add_loss.\n",
    "vae.compile(optimizer, loss=tf.keras.losses.MeanSquaredError())\n",
    "vae.fit(x_train, x_train, epochs=3, batch_size=64)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "自己编写训练方法"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Start of epoch 0\n",
      "step 0: mean loss = tf.Tensor(192.1773, shape=(), dtype=float32)\n",
      "step 100: mean loss = tf.Tensor(6.5825667, shape=(), dtype=float32)\n",
      "step 200: mean loss = tf.Tensor(3.3554409, shape=(), dtype=float32)\n",
      "step 300: mean loss = tf.Tensor(2.2692108, shape=(), dtype=float32)\n",
      "step 400: mean loss = tf.Tensor(1.7223562, shape=(), dtype=float32)\n",
      "step 500: mean loss = tf.Tensor(1.3939205, shape=(), dtype=float32)\n",
      "step 600: mean loss = tf.Tensor(1.1746095, shape=(), dtype=float32)\n",
      "step 700: mean loss = tf.Tensor(1.0176468, shape=(), dtype=float32)\n",
      "step 800: mean loss = tf.Tensor(0.8996144, shape=(), dtype=float32)\n",
      "step 900: mean loss = tf.Tensor(0.8074596, shape=(), dtype=float32)\n",
      "Start of epoch 1\n",
      "step 0: mean loss = tf.Tensor(0.77757883, shape=(), dtype=float32)\n",
      "step 100: mean loss = tf.Tensor(0.70945406, shape=(), dtype=float32)\n",
      "step 200: mean loss = tf.Tensor(0.6533016, shape=(), dtype=float32)\n",
      "step 300: mean loss = tf.Tensor(0.60621214, shape=(), dtype=float32)\n",
      "step 400: mean loss = tf.Tensor(0.5661074, shape=(), dtype=float32)\n",
      "step 500: mean loss = tf.Tensor(0.5315464, shape=(), dtype=float32)\n",
      "step 600: mean loss = tf.Tensor(0.50146633, shape=(), dtype=float32)\n",
      "step 700: mean loss = tf.Tensor(0.4750789, shape=(), dtype=float32)\n",
      "step 800: mean loss = tf.Tensor(0.4516706, shape=(), dtype=float32)\n",
      "step 900: mean loss = tf.Tensor(0.43074456, shape=(), dtype=float32)\n",
      "Start of epoch 2\n",
      "step 0: mean loss = tf.Tensor(0.42340422, shape=(), dtype=float32)\n",
      "step 100: mean loss = tf.Tensor(0.40543476, shape=(), dtype=float32)\n",
      "step 200: mean loss = tf.Tensor(0.38920242, shape=(), dtype=float32)\n",
      "step 300: mean loss = tf.Tensor(0.3744474, shape=(), dtype=float32)\n",
      "step 400: mean loss = tf.Tensor(0.3610152, shape=(), dtype=float32)\n",
      "step 500: mean loss = tf.Tensor(0.34867495, shape=(), dtype=float32)\n",
      "step 600: mean loss = tf.Tensor(0.33732668, shape=(), dtype=float32)\n",
      "step 700: mean loss = tf.Tensor(0.326859, shape=(), dtype=float32)\n",
      "step 800: mean loss = tf.Tensor(0.3171746, shape=(), dtype=float32)\n",
      "step 900: mean loss = tf.Tensor(0.30814958, shape=(), dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "train_dataset = tf.data.Dataset.from_tensor_slices(x_train)\n",
    "train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)\n",
    "\n",
    "original_dim = 784\n",
    "# Pass the sizes by keyword: VAE's positional order is (original_dim,\n",
    "# latent_dim, intermediate_dim), so a positional `64, 32` would swap them.\n",
    "vae = VAE(original_dim, latent_dim=32, intermediate_dim=64)\n",
    "\n",
    "# Optimizer\n",
    "optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)\n",
    "# Reconstruction loss\n",
    "mse_loss_fn = tf.keras.losses.MeanSquaredError()\n",
    "# Running mean of the total loss; it is never reset, so it spans all epochs\n",
    "loss_metric = tf.keras.metrics.Mean()\n",
    "\n",
    "# Training loop\n",
    "for epoch in range(3):\n",
    "    print('Start of epoch %d' % (epoch,))\n",
    "\n",
    "    # One optimization step per batch\n",
    "    for step, x_batch_train in enumerate(train_dataset):\n",
    "        with tf.GradientTape() as tape:\n",
    "            # Forward pass\n",
    "            reconstructed = vae(x_batch_train)\n",
    "            # Reconstruction loss plus the losses registered by the model\n",
    "            loss = mse_loss_fn(x_batch_train, reconstructed)\n",
    "            loss += sum(vae.losses)  # Add KLD regularization loss\n",
    "        # Gradients of the total loss w.r.t. the trainable variables\n",
    "        grads = tape.gradient(loss, vae.trainable_variables)\n",
    "        # Parameter update\n",
    "        optimizer.apply_gradients(zip(grads, vae.trainable_variables))\n",
    "        # Update the running mean\n",
    "        loss_metric(loss)\n",
    "        # Report every 100 steps\n",
    "        if step % 100 == 0:\n",
    "            print('step %s: mean loss = %s' % (step, loss_metric.result()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
